from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold, BaseCallback
import numpy as np

class CustomEval(EvalCallback):
    """
    ``EvalCallback`` variant that does nothing until a warm-up period is over.

    While ``self.num_timesteps`` is below ``warm_starts`` the parent
    ``_on_step`` is not invoked at all; afterwards every step is delegated
    to ``EvalCallback`` unchanged.

    All parameters except ``warm_starts`` are forwarded as-is to
    ``EvalCallback.__init__``.

    :param eval_env: Forwarded to ``EvalCallback``.
    :param callback_on_new_best: Forwarded to ``EvalCallback``.
    :param callback_after_eval: Forwarded to ``EvalCallback``.
    :param n_eval_episodes: Forwarded to ``EvalCallback``.
    :param eval_freq: Forwarded to ``EvalCallback``.
    :param log_path: Forwarded to ``EvalCallback``.
    :param best_model_save_path: Forwarded to ``EvalCallback``.
    :param deterministic: Forwarded to ``EvalCallback``.
    :param render: Forwarded to ``EvalCallback``.
    :param verbose: Forwarded to ``EvalCallback``.
    :param warn: Forwarded to ``EvalCallback``.
    :param warm_starts: Value compared against ``self.num_timesteps``;
        the parent ``_on_step`` is skipped while ``num_timesteps`` is
        strictly below it. The default of 0 disables the warm-up.
    """

    def __init__(self,
                 eval_env,
                 callback_on_new_best=None,
                 callback_after_eval=None,
                 n_eval_episodes=5,
                 eval_freq=10000,
                 log_path=None,
                 best_model_save_path=None,
                 deterministic=True,
                 render=False,
                 verbose=1,
                 warn=True,
                 warm_starts=0):
        # Forward by keyword rather than by position so that each value is
        # bound to the parameter of the same name, independent of the order
        # in which ``EvalCallback`` declares them.
        super().__init__(
            eval_env,
            callback_on_new_best=callback_on_new_best,
            callback_after_eval=callback_after_eval,
            n_eval_episodes=n_eval_episodes,
            eval_freq=eval_freq,
            log_path=log_path,
            best_model_save_path=best_model_save_path,
            deterministic=deterministic,
            render=render,
            verbose=verbose,
            warn=warn,
        )
        self.warm_starts = warm_starts

    def _on_step(self) -> bool:
        """
        Delegate to ``EvalCallback._on_step`` once the warm-up is over.

        :return: ``True`` during the warm-up, otherwise whatever the parent
            ``_on_step`` returns.
        """
        if self.num_timesteps < self.warm_starts:
            # Still warming up: skip the parent's step logic entirely.
            return True
        return super()._on_step()

class CustomEarlyStop(BaseCallback):
    """
    Stop the training once a threshold in episodic reward has been reached
    (i.e. when the model is good enough), or once the best mean reward has
    stopped improving for too many consecutive calls.

    It must be used with the ``EvalCallback``: the value of
    ``self.parent.best_mean_reward`` is read on every call.

    :param reward_threshold: Minimum expected reward per episode
        to stop training. Training continues only while the parent's
        best mean reward is strictly below this value.
    :param max_no_improvement_evals: Number of consecutive calls without a
        new best mean reward that is tolerated; training is stopped when
        the count exceeds this value.
    :param min_evals: Lack of improvement is only counted once
        ``self.n_calls`` is greater than this value.
    :param verbose: If greater than 0, print the reason when training
        is stopped.
    """

    def __init__(self, reward_threshold, max_no_improvement_evals, min_evals, verbose=0):
        super().__init__(verbose=verbose)
        self.reward_threshold = reward_threshold
        self.max_no_improvement_evals = max_no_improvement_evals
        self.min_evals = min_evals
        # Consecutive calls (after ``min_evals``) without a new best reward.
        self.no_improvement_evals = 0
        # Parent's best mean reward as seen on the previous call.
        self.last_best_mean_reward = -np.inf

    def _on_step(self) -> bool:
        """
        Decide whether training should continue.

        :return: ``False`` when the reward threshold has been reached or the
            no-improvement budget is exhausted, ``True`` otherwise. Always a
            plain Python ``bool``.
        """
        assert self.parent is not None, "``CustomEarlyStop`` callback must be used " "with an ``EvalCallback``"

        best_mean_reward = self.parent.best_mean_reward

        # ``not`` always yields a plain Python bool, even if the comparison
        # itself returns np.bool_; callers may test the result with
        # ``is False``.
        threshold_reached = not (best_mean_reward < self.reward_threshold)
        if threshold_reached:
            if self.verbose > 0:
                print(
                    f"Stopping training because the mean reward {best_mean_reward:.2f} "
                    f"is above the threshold {self.reward_threshold}"
                )
            return False

        continue_training = True
        # Lack of improvement is ignored for the first ``min_evals`` calls.
        if self.n_calls > self.min_evals:
            if best_mean_reward > self.last_best_mean_reward:
                self.no_improvement_evals = 0
            else:
                self.no_improvement_evals += 1
                if self.no_improvement_evals > self.max_no_improvement_evals:
                    continue_training = False

        # Remember the current best for the comparison on the next call.
        self.last_best_mean_reward = best_mean_reward

        if self.verbose > 0 and not continue_training:
            print(
                f"Stopping training because there was no new best model in the last {self.no_improvement_evals:d} evaluations"
            )

        return continue_training